{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinforcement Learning: Deep Q-Networks\n",
    "\n",
    "If you aren't familiar with reinforcement learning, check out the previous guide on reinforcement learning for an introduction.\n",
    "\n",
    "In the previous guide we implemented the Q function as a lookup table. That worked well enough there because the state space was fairly small. However, consider something like [DeepMind's Atari player](http://www.wired.co.uk/article/google-deepmind-atari). A state in that task is a unique configuration of pixels. The Atari games are in color, so each pixel has three values (R, G, B), and there are quite a few pixels. The number of possible pixel configurations is therefore enormous, and a lookup table with an entry for every one of these states would simply take too much memory.\n",
    "\n",
    "Instead, we can learn a Q _function_ that approximately maps a set of pixel values and an action to a value. We can implement this Q function as a neural network and have it learn to predict the expected future reward for each action given an input state. This is the general idea behind _deep Q-learning_ (i.e. deep Q-networks, or DQNs).\n",
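    "\n",
    "For a sense of scale in our own catch game, here is a rough comparison, assuming the 10x10 black-and-white grid and the two hidden layers of 100 units used by the agent below. Counting every possible 0/1 grid only gives an upper bound (most of those grids can never occur in the actual game, so a table would in fact be feasible here), but it shows the key difference: the table grows exponentially with the number of pixels, while the network's size depends only on its layer widths.\n",
    "\n",
    "```python\n",
    "n_pixels = 10 * 10\n",
    "n_actions = 3\n",
    "hidden = 100\n",
    "\n",
    "# upper bound on table size: each pixel is 0 or 1, so at most 2**100\n",
    "# distinct grids, each needing one entry per action\n",
    "table_entries = (2 ** n_pixels) * n_actions\n",
    "\n",
    "# a dense layer has (inputs * outputs) weights plus one bias per output\n",
    "layer_sizes = [n_pixels, hidden, hidden, n_actions]\n",
    "n_params = sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))\n",
    "\n",
    "print(n_params)  # 20503\n",
    "print(table_entries > 10 ** 30)  # True\n",
    "```\n",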
    "\n",
    "Here we'll put together a simple DQN agent that learns to play a simple game of catch. The agent controls a paddle at the bottom of the screen that it can move left, move right, or keep still (so there are three possible actions). An object falls from the top of the screen, and the agent wins if it catches it (a reward of +1). Otherwise, it loses (a reward of -1).\n",
    "\n",
    "We'll implement the game in black-and-white so that the pixels in the game can be represented as 1 or 0.\n",
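    "\n",
    "As a small illustration (a simplified stand-in, not the game code itself), a 10x10 screen with the falling object in the top row and a three-pixel paddle in the bottom row flattens into a 1x100 vector of 0s and 1s. The hypothetical `catch_reward` helper mirrors the game's catch rule: the object is caught if its column is within one pixel of the paddle's center.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "height, width = 10, 10\n",
    "paddle_padding = 1  # the paddle extends one pixel either side of its center\n",
    "\n",
    "grid = np.zeros((height, width))\n",
    "object_col, paddle_col = 4, 3\n",
    "grid[0, object_col] = 1  # falling object\n",
    "grid[height - 1, paddle_col - paddle_padding:paddle_col + paddle_padding + 1] = 1  # paddle\n",
    "\n",
    "state = grid.reshape((1, -1))  # a single row vector of 100 zeros and ones\n",
    "\n",
    "def catch_reward(object_col, paddle_col, padding=paddle_padding):\n",
    "    # hypothetical helper: +1 if the object lands on the paddle, else -1\n",
    "    return 1 if abs(object_col - paddle_col) <= padding else -1\n",
    "\n",
    "print(state.shape)         # (1, 100)\n",
    "print(int(state.sum()))    # 4: one object pixel plus three paddle pixels\n",
    "print(catch_reward(4, 3))  # 1\n",
    "print(catch_reward(4, 6))  # -1\n",
    "```\n",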
    "\n",
    "Using a DQN is much like using neural networks in ways you may be more familiar with. Here we'll take a vector that represents the screen, feed it through the network, and the network will output one value for each possible action (these are estimated future rewards, not probabilities). You can roughly think of it as a classification problem: given this input state, label it with the best action to take, i.e. the action with the highest value.\n",
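    "\n",
    "A minimal NumPy sketch of that forward pass, with random weights standing in for a trained network (the layer sizes mirror the agent defined below; the weights and the example state are purely illustrative). The output has one entry per action, and the greedy action is simply the index of the largest entry:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "n_inputs, n_hidden, n_actions = 100, 100, 3\n",
    "\n",
    "# random weights standing in for a trained network\n",
    "W1, b1 = rng.randn(n_inputs, n_hidden) * 0.1, np.zeros(n_hidden)\n",
    "W2, b2 = rng.randn(n_hidden, n_hidden) * 0.1, np.zeros(n_hidden)\n",
    "W3, b3 = rng.randn(n_hidden, n_actions) * 0.1, np.zeros(n_actions)\n",
    "\n",
    "def relu(x):\n",
    "    return np.maximum(x, 0)\n",
    "\n",
    "def q_values(state):\n",
    "    # state: (1, 100) flattened screen -> (1, 3), one value per action\n",
    "    h = relu(state.dot(W1) + b1)\n",
    "    h = relu(h.dot(W2) + b2)\n",
    "    return h.dot(W3) + b3  # linear output layer: values, not probabilities\n",
    "\n",
    "state = np.zeros((1, n_inputs))\n",
    "state[0, 4] = 1      # falling object in the top row\n",
    "state[0, 92:95] = 1  # three-pixel paddle in the bottom row\n",
    "\n",
    "q = q_values(state)\n",
    "best_action = int(np.argmax(q[0]))  # 0 = left, 1 = stay, 2 = right\n",
    "```\n",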
    "\n",
    "For example, this is the [architecture of the Atari player](http://home.uchicago.edu/~arij/journalclub/papers/2015_Mnih_et_al.pdf):\n",
    "\n",
    "![Atari player DQN](../assets/atari_dqn.png)\n",
    "\n",
    "The scenario we're dealing with is simple enough that we don't need convolutional neural networks, but we could easily extend it in that way if we wanted (just replace our vanilla neural network with a convolutional one).\n",
    "\n",
    "Here's what our catch game will look like:\n",
    "\n",
    "![Simple catch game](../assets/catch_game.png)\n",
    "\n",
    "To start, I'll present the code for the catch game itself. It's not important that you understand this code; the part we care about is the agent.\n",
    "\n",
    "Note that the game can only be visualized (via its `render` method) when the code is run in a terminal, not inside a notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from blessings import Terminal\n",
    "\n",
    "class Game():\n",
    "    def __init__(self, shape=(10,10)):\n",
    "        self.shape = shape\n",
    "        self.height, self.width = shape\n",
    "        self.last_row = self.height - 1\n",
    "        self.paddle_padding = 1\n",
    "        self.n_actions = 3 # left, stay, right\n",
    "        self.term = Terminal()\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        # reset grid\n",
    "        self.grid = np.zeros(self.shape)\n",
    "\n",
    "        # can only move left or right (or stay)\n",
    "        # so position is only its horizontal position (col)\n",
    "        # randint excludes its upper bound, so this allows every valid position\n",
    "        self.pos = np.random.randint(self.paddle_padding, self.width - self.paddle_padding)\n",
    "        self.set_paddle(1)\n",
    "\n",
    "        # item to catch\n",
    "        self.target = (0, np.random.randint(self.width))  # any column from 0 to width - 1\n",
    "        self.set_position(self.target, 1)\n",
    "\n",
    "    def move(self, action):\n",
    "        # clear previous paddle position\n",
    "        self.set_paddle(0)\n",
    "\n",
    "        # action is either -1, 0, 1,\n",
    "        # but comes in as 0, 1, 2, so subtract 1\n",
    "        action -= 1\n",
    "        self.pos = min(max(self.pos + action, self.paddle_padding), self.width - 1 - self.paddle_padding)\n",
    "\n",
    "        # set new paddle position\n",
    "        self.set_paddle(1)\n",
    "\n",
    "    def set_paddle(self, val):\n",
    "        for i in range(1 + self.paddle_padding*2):\n",
    "            pos = self.pos - self.paddle_padding + i\n",
    "            self.set_position((self.last_row, pos), val)\n",
    "\n",
    "    @property\n",
    "    def state(self):\n",
    "        return self.grid.reshape((1,-1)).copy()\n",
    "\n",
    "    def set_position(self, pos, val):\n",
    "        r, c = pos\n",
    "        self.grid[r,c] = val\n",
    "\n",
    "    def update(self):\n",
    "        r, c = self.target\n",
    "\n",
    "        self.set_position(self.target, 0)\n",
    "        self.set_paddle(1) # in case the target is on the paddle\n",
    "        self.target = (r+1, c)\n",
    "        self.set_position(self.target, 1)\n",
    "\n",
    "        # the target has reached the paddle's row, so the game is over\n",
    "        if r + 1 == self.last_row:\n",
    "            # reward of 1 if collided with paddle, else -1\n",
    "            if abs(c - self.pos) <= self.paddle_padding:\n",
    "                return 1\n",
    "            else:\n",
    "                return -1\n",
    "\n",
    "        return 0\n",
    "\n",
    "    def render(self):\n",
    "        print(self.term.clear())\n",
    "        for r, row in enumerate(self.grid):\n",
    "            for c, on in enumerate(row):\n",
    "                if on:\n",
    "                    color = 235\n",
    "                else:\n",
    "                    color = 229\n",
    "\n",
    "                print(self.term.move(r, c) + self.term.on_color(color) + ' ' + self.term.normal)\n",
    "\n",
    "        # move cursor to end\n",
    "        print(self.term.move(self.height, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ok, on to the agent itself. I'll present the code in full here, then explain parts in more detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using Theano backend.\n",
      "WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10).  Please switch to the gpuarray backend. You can get more information about how to switch at this URL:\n",
      " https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29\n",
      "\n",
      "Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 5103)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "#if using Theano with GPU\n",
    "#os.environ[\"THEANO_FLAGS\"] = \"mode=FAST_RUN,device=gpu,floatX=float32\"\n",
    "\n",
    "import random\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from collections import deque\n",
    "\n",
    "class Agent():\n",
    "    def __init__(self, env, explore=0.1, discount=0.9, hidden_size=100, memory_limit=5000):\n",
    "        self.env = env\n",
    "        model = Sequential()\n",
    "        model.add(Dense(hidden_size, input_shape=(env.height * env.width,), activation='relu'))\n",
    "        model.add(Dense(hidden_size, activation='relu'))\n",
    "        model.add(Dense(env.n_actions))\n",
    "        model.compile(loss='mse', optimizer='sgd')\n",
    "        self.Q = model\n",
    "\n",
    "        # experience replay:\n",
    "        # remember states to \"reflect\" on later\n",
    "        self.memory = deque([], maxlen=memory_limit)\n",
    "\n",
    "        self.explore = explore\n",
    "        self.discount = discount\n",
    "\n",
    "    def choose_action(self):\n",
    "        if np.random.rand() <= self.explore:\n",
    "            return np.random.randint(0, self.env.n_actions)\n",
    "        state = self.env.state\n",
    "        q = self.Q.predict(state)\n",
    "        return np.argmax(q[0])\n",
    "\n",
    "    def remember(self, state, action, next_state, reward):\n",
    "        # the deque object will automatically keep a fixed length\n",
    "        self.memory.append((state, action, next_state, reward))\n",
    "\n",
    "    def _prep_batch(self, batch_size):\n",
    "        if batch_size > self.memory.maxlen:\n",
    "            import warnings\n",
    "            warnings.warn('batch size should not be larger than max memory size. Setting batch size to memory size')\n",
    "            batch_size = self.memory.maxlen\n",
    "\n",
    "        batch_size = min(batch_size, len(self.memory))\n",
    "\n",
    "        inputs = []\n",
    "        targets = []\n",
    "\n",
    "        # prep the batch\n",
    "        # inputs are states, outputs are values over actions\n",
    "        batch = random.sample(list(self.memory), batch_size)\n",
    "        for state, action, next_state, reward in batch:\n",
    "            inputs.append(state)\n",
    "            target = self.Q.predict(state)[0]\n",
    "\n",
    "            # debug, \"this should never happen\"\n",
    "            assert not np.array_equal(state, next_state)\n",
    "\n",
    "            # non-zero reward indicates terminal state\n",
    "            if reward:\n",
    "                target[action] = reward\n",
    "            else:\n",
    "                # reward + gamma * max_a' Q(s', a')\n",
    "                Q_sa = np.max(self.Q.predict(next_state)[0])\n",
    "                target[action] = reward + self.discount * Q_sa\n",
    "            targets.append(target)\n",
    "\n",
    "        # to numpy matrices\n",
    "        return np.vstack(inputs), np.vstack(targets)\n",
    "\n",
    "    def replay(self, batch_size):\n",
    "        inputs, targets = self._prep_batch(batch_size)\n",
    "        loss = self.Q.train_on_batch(inputs, targets)\n",
    "        return loss\n",
    "\n",
    "    def save(self, fname):\n",
    "        self.Q.save_weights(fname)\n",
    "\n",
    "    def load(self, fname):\n",
    "        self.Q.load_weights(fname)\n",
    "        print(self.Q.get_weights())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You'll see that this is quite similar to the previous Q-learning agent we implemented. There are explore and discount values, for example. But the Q function is now a neural network.\n",
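    "\n",
    "The `explore` value drives an epsilon-greedy policy, as in `choose_action`: with probability `explore` the agent takes a random action, and otherwise it takes the action with the highest predicted value. A standalone sketch, with made-up Q-values in place of the network's predictions:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def epsilon_greedy(q_values, explore, rng):\n",
    "    # with probability `explore`, take a random action to keep exploring\n",
    "    if rng.rand() < explore:\n",
    "        return int(rng.randint(0, len(q_values)))\n",
    "    # otherwise exploit the current value estimates\n",
    "    return int(np.argmax(q_values))\n",
    "\n",
    "rng = np.random.RandomState(42)\n",
    "q = np.array([0.1, 0.7, -0.3])  # hypothetical values for left, stay, right\n",
    "\n",
    "greedy = epsilon_greedy(q, 0.0, rng)  # always 1 (stay) when explore is 0\n",
    "actions = [epsilon_greedy(q, 0.1, rng) for _ in range(1000)]\n",
    "frac_best = actions.count(1) / float(len(actions))  # about 0.9 + 0.1 / 3\n",
    "```\n",
    "\n",
    "Random exploration also picks the best action a third of the time, so with `explore=0.1` the greedy action is taken in roughly 93% of steps.\n",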
    "\n",
    "The biggest differences are the `remember` and `replay` methods.\n",
    "\n",
    "A challenge with DQNs is that training can be unstable. In particular, they can suffer from a problem known as _catastrophic forgetting_, in which later experiences overwrite what was learned from earlier ones. When this happens, the agent can only take advantage of what it has learned most recently, not everything it has learned.\n",
    "\n",
    "A method to deal with this is called _experience replay_. We store experienced states, the actions taken, and the resulting states and rewards (as \"memories\"), then after each action we sample a random batch of these memories (this is what the `_prep_batch` method does) and use them to train the neural network (i.e. \"replay\" these remembered experiences). Sampling at random also breaks up the strong correlations between consecutive steps of the same game, which further helps stabilize training. This will become clearer in the code below, where we actually train the agent."
   ]
  },
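  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stripped of the neural network, the memory mechanism amounts to a bounded `deque` plus uniform random sampling. A sketch with a tiny capacity so that the eviction of old memories is visible (the transitions are placeholder numbers, not real game states):\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import deque\n",
    "\n",
    "memory = deque([], maxlen=5)  # keeps only the 5 most recent transitions\n",
    "for t in range(8):\n",
    "    # placeholder transition: (state, action, next_state, reward)\n",
    "    memory.append((t, t % 3, t + 1, 0))\n",
    "\n",
    "# transitions 0, 1 and 2 were pushed out once the deque was full\n",
    "print([m[0] for m in memory])  # [3, 4, 5, 6, 7]\n",
    "\n",
    "# uniform sample without replacement, as in _prep_batch\n",
    "batch = random.sample(list(memory), 3)\n",
    "```"
   ]
  },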
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training...\n",
      "epoch: 0001/6500 | loss: 0.047 | win rate: 0.000\n",
      "epoch: 0101/6500 | loss: 0.293 | win rate: 0.000\n",
      "epoch: 0201/6500 | loss: 0.292 | win rate: 0.000\n",
      "epoch: 0301/6500 | loss: 0.285 | win rate: 0.000\n",
      "epoch: 0401/6500 | loss: 0.254 | win rate: 0.000\n",
      "epoch: 0501/6500 | loss: 0.249 | win rate: 0.000\n",
      "epoch: 0601/6500 | loss: 0.244 | win rate: 0.000\n",
      "epoch: 0701/6500 | loss: 0.219 | win rate: 0.000\n",
      "epoch: 0801/6500 | loss: 0.182 | win rate: 0.000\n",
      "epoch: 0901/6500 | loss: 0.159 | win rate: 0.000\n",
      "epoch: 1001/6500 | loss: 0.145 | win rate: 0.000\n",
      "epoch: 1101/6500 | loss: 0.163 | win rate: 0.000\n",
      "epoch: 1201/6500 | loss: 0.170 | win rate: 0.000\n",
      "epoch: 1301/6500 | loss: 0.151 | win rate: 0.000\n",
      "epoch: 1401/6500 | loss: 0.179 | win rate: 0.000\n",
      "epoch: 1501/6500 | loss: 0.143 | win rate: 0.000\n",
      "epoch: 1601/6500 | loss: 0.149 | win rate: 0.000\n",
      "epoch: 1701/6500 | loss: 0.162 | win rate: 0.000\n",
      "epoch: 1801/6500 | loss: 0.173 | win rate: 0.000\n",
      "epoch: 1901/6500 | loss: 0.154 | win rate: 0.000\n",
      "epoch: 2001/6500 | loss: 0.153 | win rate: 0.000\n",
      "epoch: 2101/6500 | loss: 0.141 | win rate: 0.000\n",
      "epoch: 2201/6500 | loss: 0.160 | win rate: 0.000\n",
      "epoch: 2301/6500 | loss: 0.171 | win rate: 0.000\n",
      "epoch: 2401/6500 | loss: 0.160 | win rate: 0.000\n",
      "epoch: 2501/6500 | loss: 0.167 | win rate: 0.000\n",
      "epoch: 2601/6500 | loss: 0.170 | win rate: 0.000\n",
      "epoch: 2701/6500 | loss: 0.166 | win rate: 0.000\n",
      "epoch: 2801/6500 | loss: 0.131 | win rate: 0.000\n",
      "epoch: 2901/6500 | loss: 0.115 | win rate: 0.000\n",
      "epoch: 3001/6500 | loss: 0.111 | win rate: 0.000\n",
      "epoch: 3101/6500 | loss: 0.092 | win rate: 0.000\n",
      "epoch: 3201/6500 | loss: 0.110 | win rate: 0.000\n",
      "epoch: 3301/6500 | loss: 0.092 | win rate: 0.000\n",
      "epoch: 3401/6500 | loss: 0.097 | win rate: 0.000\n",
      "epoch: 3501/6500 | loss: 0.086 | win rate: 0.000\n",
      "epoch: 3601/6500 | loss: 0.106 | win rate: 0.000\n",
      "epoch: 3701/6500 | loss: 0.082 | win rate: 0.000\n",
      "epoch: 3801/6500 | loss: 0.088 | win rate: 0.000\n",
      "epoch: 3901/6500 | loss: 0.092 | win rate: 0.000\n",
      "epoch: 4001/6500 | loss: 0.082 | win rate: 0.000\n",
      "epoch: 4101/6500 | loss: 0.094 | win rate: 0.000\n",
      "epoch: 4201/6500 | loss: 0.089 | win rate: 0.000\n",
      "epoch: 4301/6500 | loss: 0.091 | win rate: 0.000\n",
      "epoch: 4401/6500 | loss: 0.079 | win rate: 0.000\n",
      "epoch: 4501/6500 | loss: 0.073 | win rate: 0.000\n",
      "epoch: 4601/6500 | loss: 0.062 | win rate: 0.000\n",
      "epoch: 4701/6500 | loss: 0.077 | win rate: 0.000\n",
      "epoch: 4801/6500 | loss: 0.047 | win rate: 0.000\n",
      "epoch: 4901/6500 | loss: 0.052 | win rate: 0.000\n",
      "epoch: 5001/6500 | loss: 0.068 | win rate: 0.000\n",
      "epoch: 5101/6500 | loss: 0.045 | win rate: 0.000\n",
      "epoch: 5201/6500 | loss: 0.053 | win rate: 0.000\n",
      "epoch: 5301/6500 | loss: 0.056 | win rate: 0.000\n",
      "epoch: 5401/6500 | loss: 0.040 | win rate: 0.000\n",
      "epoch: 5501/6500 | loss: 0.036 | win rate: 0.000\n",
      "epoch: 5601/6500 | loss: 0.050 | win rate: 0.000\n",
      "epoch: 5701/6500 | loss: 0.034 | win rate: 0.000\n",
      "epoch: 5801/6500 | loss: 0.052 | win rate: 0.000\n",
      "epoch: 5901/6500 | loss: 0.045 | win rate: 0.000\n",
      "epoch: 6001/6500 | loss: 0.049 | win rate: 0.000\n",
      "epoch: 6101/6500 | loss: 0.034 | win rate: 0.000\n",
      "epoch: 6201/6500 | loss: 0.043 | win rate: 0.000\n",
      "epoch: 6301/6500 | loss: 0.035 | win rate: 0.000\n",
      "epoch: 6401/6500 | loss: 0.042 | win rate: 0.000\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "from time import sleep\n",
    "game = Game()\n",
    "agent = Agent(game)\n",
    "\n",
    "print('training...')\n",
    "epochs = 6500\n",
    "batch_size = 256\n",
    "fname = 'game_weights.h5'\n",
    "\n",
    "# keep track of past record_len results\n",
    "record_len = 100\n",
    "record = deque([], record_len)\n",
    "\n",
    "for i in range(epochs):\n",
    "    game.reset()\n",
    "    reward = 0\n",
    "    loss = 0\n",
    "    # rewards only given at end of game\n",
    "    while reward == 0:\n",
    "        prev_state = game.state\n",
    "        action = agent.choose_action()\n",
    "        game.move(action)\n",
    "        reward = game.update()\n",
    "        new_state = game.state\n",
    "        \n",
    "        # debug, \"this should never happen\"\n",
    "        assert not np.array_equal(new_state, prev_state)\n",
    "\n",
    "        agent.remember(prev_state, action, new_state, reward)\n",
    "        loss += agent.replay(batch_size)\n",
    "\n",
    "    # win rate over the last record_len games; float() avoids integer division under Python 2\n",
    "    win_rate = float(sum(record)) / len(record) if record else 0.0\n",
    "\n",
    "    # if running in a terminal, use these instead of print:\n",
    "    #sys.stdout.flush()\n",
    "    #sys.stdout.write('epoch: {:04d}/{} | loss: {:.3f} | win rate: {:.3f}\\r'.format(i+1, epochs, loss, win_rate))\n",
    "    if i % 100 == 0:\n",
    "        print('epoch: {:04d}/{} | loss: {:.3f} | win rate: {:.3f}'.format(i+1, epochs, loss, win_rate))\n",
    "    \n",
    "    record.append(reward if reward == 1 else 0)\n",
    "\n",
    "agent.save(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we train the agent for 6500 epochs (that is, 6500 games). We also keep a trailing record of its wins to see if its win rate is improving.\n",
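    "\n",
    "The trailing record is a `deque` with a maximum length, so the oldest result is dropped each time a new one is added. One detail to watch under this notebook's Python 2 kernel: dividing two integers floors the result, so `sum(record)/len(record)` stays at 0 unless every recent game was won. Converting to `float` first gives the intended fraction. A sketch with made-up results:\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "record = deque([], maxlen=4)\n",
    "for result in [1, 0, 1, 1, 0, 1]:  # hypothetical outcomes: 1 = win, 0 = loss\n",
    "    record.append(result)\n",
    "\n",
    "# only the last 4 results remain: [1, 1, 0, 1]\n",
    "win_rate = float(sum(record)) / len(record) if record else 0.0\n",
    "print(win_rate)  # 0.75\n",
    "```\n",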
    "\n",
    "A game goes on until the reward is non-zero, which means the agent has either lost (reward of -1) or won (reward of +1). Note that after each action the agent \"remembers\" the state it was in, the action it took, the state that resulted, and the reward it received. It then \"replays\" a random batch of past experiences to update its neural network.\n",
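    "\n",
    "Each replayed memory is turned into a regression target for the network, which is what `_prep_batch` does. For a terminal transition (a non-zero reward in this game) the target for the action taken is just the reward $r$; otherwise it is $r + \\gamma \\max_{a'} Q(s', a')$, where $\\gamma$ is the `discount`. The targets for the other actions are copied from the network's current predictions, so they add nothing to the mean squared error. A sketch with made-up Q-values in place of network predictions:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "discount = 0.9  # the agent's default discount\n",
    "\n",
    "def make_target(q_s, action, reward, q_next):\n",
    "    # copy current predictions so only the taken action's output is corrected\n",
    "    target = q_s.copy()\n",
    "    if reward != 0:  # in the catch game a non-zero reward ends the game\n",
    "        target[action] = reward\n",
    "    else:\n",
    "        target[action] = reward + discount * np.max(q_next)\n",
    "    return target\n",
    "\n",
    "q_s = np.array([0.1, 0.3, -0.2])     # hypothetical Q(s, a) for each action\n",
    "q_next = np.array([0.2, 0.5, -0.1])  # hypothetical Q(s', a') for each action\n",
    "\n",
    "mid_game = make_target(q_s, 2, 0, q_next)  # [0.1, 0.3, 0.45]\n",
    "caught = make_target(q_s, 1, 1, q_next)    # [0.1, 1.0, -0.2]\n",
    "```\n",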
    "\n",
    "Once the agent is trained, we can play a round and see if it wins.\n",
    "\n",
    "Depicted below are the results from my training:\n",
    "\n",
    "![](../assets/dqn_training.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "winner!\n"
     ]
    }
   ],
   "source": [
    "# play a round, acting greedily (no random exploration)\n",
    "agent.explore = 0\n",
    "game.reset()\n",
    "#game.render() # uncomment to watch the game; rendering only works from a terminal, not inside a notebook\n",
    "reward = 0\n",
    "while reward == 0:\n",
    "    action = agent.choose_action()\n",
    "    game.move(action)\n",
    "    reward = game.update()\n",
    "    #game.render()\n",
    "    sleep(0.1)\n",
    "print('winner!' if reward == 1 else 'loser!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After 6500 epochs, the agent I trained won about 90% of the time. Not bad, considering it started at around 30%!\n",
    "\n",
    "![Winner](../assets/catch_game_trained.gif)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
